#### US simulation function
my.sim = function(shock.date, price.path.sim, lm.regs.sim, 
                  stp.den, base.spuds, 
                  ex.well.liq, ex.well.gas, 
                  roy.base, sim.start, profiles.pct.liq.class, profiles.pct.gas.class,
                  ip.liq.class, ip.gas.class,
                  spudcounts.class,
                  use.spudreg.trend = FALSE,
                  price.floor=0.1, # floor on the net WTI price; the gas floor is price.floor/6
                  roy.fed.new.on = NULL, # new federal onshore royalty rate. if null, uses baseline. 
                  roy.fed.new.off = NULL, # new federal offshore royalty rate. if null, uses baseline. 
                  fed.leasing.ban = FALSE, ban.lag=0, # lag in ban taking effect, in years
                  onshore.only=FALSE, # if TRUE, apply policy only to onshore. otherwise, applies to both onshore and offshore
                  carb.add.fed.oil = 0, # in $/tCO2e
                  carb.add.fed.gas = 0,
                  carb.add.nonfed.oil = 0,
                  carb.add.nonfed.gas = 0,
                  carb.add.growth.rate = 0.02,
                  oil.emis.rate.fed = 0.43, oil.emis.rate.nonfed = 0.43, # tCO2/bbl
                  # tCO2e/mcf, including 2.36% methane leakage at 100 year GWP=34 (Raimi email on 5/4, 10:35am)
                  gas.emis.rate.fed = (117+28.55)/2204.62, gas.emis.rate.nonfed = (117+28.55)/2204.62,
                  lease.length=10) {
  # Simulates US spuds, wells beginning production, and oil/gas production by
  # well class under a royalty-rate change, carbon adder, and/or federal
  # leasing ban that takes effect at 'shock.date'.
  #
  # Arguments:
  # shock.date...date from which the new royalty rate and carbon adders apply
  # price.path.sim...data.table with (at least) columns 'date', 'wti' and 'hh';
  #      it is also passed as 'newdata' to predict() for each regression
  # lm.regs.sim...array of fitted regressions. Its dimnames define the well
  #      classes; dimension 2 must contain 'Nonfederal' and dimension 3 is
  #      compared against 'Onshore'/'Offshore'
  # stp.den...array whose first dimension is the lag and whose remaining
  #      dimensions are the well classes (presumably the spud-to-production
  #      density - confirm against the caller)
  # base.spuds, roy.base...arrays of baseline spuds and baseline royalty rates,
  #      dimensioned by well class
  # ex.well.liq, ex.well.gas...data.tables of production from existing wells,
  #      with 'date' plus the same column names as the new-well production tables
  # sim.start...first simulated date
  # profiles.pct.liq.class, profiles.pct.gas.class...production profile matrices
  #      with one column per well class
  # ip.liq.class, ip.gas.class...initial production by well class, with class
  #      names as colnames (presumably one-row matrices - TODO confirm)
  # spudcounts.class...data.table of historical spud counts; first column is
  #      the date, remaining columns are the well classes
  # use.spudreg.trend...if FALSE, the intercept plus the average month-of-year
  #      effect is removed from the predicted change in spuds
  # ban.lag...NOTE(review): accepted but not used anywhere in this function
  # lease.length...years over which federal drilling moves from 'Uncovered'
  #      to 'Covered' after shock.date
  #
  # NOTE(review): 'lag.length' is not an argument; it is resolved from the
  # enclosing environment, as are the helpers rep.row(), elapsed_months()
  # and %&%.
  #
  # Returns a named list: OilProduction_AllWells, GasProduction_AllWells,
  # OilProduction_ExWells, GasProduction_ExWells, OilProduction_NewWells,
  # GasProduction_NewWells, spuds, prods, PricePaths, IPs,
  # OilProduction_NewWellsCovered, GasProduction_NewWellsCovered.
  
  # Every well class is split into a policy-'Covered' and an 'Uncovered' copy.
  well.classes = c(dimnames(lm.regs.sim), list(c("Covered","Uncovered")))
  well.class.grid = expand.grid(well.classes)
  well.class.string = apply(as.matrix(well.class.grid), 1, paste, collapse='_')
  
  profile.names = c(colnames(profiles.pct.liq.class)%&%'_Covered',
                    colnames(profiles.pct.liq.class)%&%'_Uncovered')
  
  # Duplicate every class-level input so it exists for both coverage states
  profiles.pct.liq.class = cbind(profiles.pct.liq.class, profiles.pct.liq.class)
  profiles.pct.gas.class = cbind(profiles.pct.gas.class, profiles.pct.gas.class)
  ip.liq.class = cbind(ip.liq.class, ip.liq.class)
  ip.gas.class = cbind(ip.gas.class, ip.gas.class)
  
  colnames(profiles.pct.liq.class) <- colnames(profiles.pct.gas.class) <- 
    colnames(ip.liq.class) <- colnames(ip.gas.class) <- profile.names
  
  stp.den = array(stp.den, dim=c(dim(stp.den), 2),
                  dimnames=c(list(1:dim(stp.den)[1]), well.classes))
  base.spuds = array(base.spuds, dim=c(dim(base.spuds), 2),
                     dimnames=well.classes)
  roy.base = array(roy.base, dim=c(dim(roy.base), 2),
                   dimnames=well.classes)
  
  # Consistency checks: all inputs must be ordered by the same well classes
  if (!identical(well.class.string, colnames(profiles.pct.liq.class))) stop('Class mismatch: prof.pct.liq')
  if (!identical(well.class.string, colnames(profiles.pct.gas.class))) stop('Class mismatch: prof.pct.gas')
  if (!identical(well.class.string, colnames(ip.liq.class))) stop('Class mismatch: ip.liq.class')
  if (!identical(well.class.string, colnames(ip.gas.class))) stop('Class mismatch: ip.gas.class')
  
  if (!identical(well.classes, dimnames(stp.den)[-1])) stop('Class mismatch: stp')
  if (!identical(well.classes, dimnames(base.spuds))) stop('Class mismatch: base.spuds')
  if (!identical(well.classes, dimnames(roy.base))) stop('Class mismatch: roy.base')
  
  # One price path per well class (list-array with the class dimensions)
  output <- array(list(copy(price.path.sim)), 
                  dim=sapply(well.classes, length),
                  dimnames=well.classes)
  
  # The code below reads federal status from column 2 of well.class.grid,
  # so the dimension holding 'Nonfederal' has to be the second one.
  fed.dimension = which(vapply(well.classes, function(x) 'Nonfederal' %in% x, logical(1)))
  if (length(fed.dimension)!=1 || fed.dimension!=2) stop('Error: Federal dimension is not 2')
  
  # Loop through all cells of output, determining net oil and gas prices
  # after royalties and carbon adders.
  for (i in seq_len(nrow(well.class.grid))) {
    
    # Make a copy of the data.table
    # If you don't do this, and instead try to modify output[[i]] by reference,
    # it will change ALL data.tables in output, not just the current one.
    # Making a copy makes sure you're only modifying this particular one.
    price.temp = copy(output[[i]])
    price.temp = price.temp[date>=sim.start-months(1+lag.length)]
    
    fed.status = well.class.grid[i,2]
    shore.status = well.class.grid[i,3]
    covered.status = well.class.grid[i,4]
    policy.applies = fed.status=='Federal' && covered.status=='Covered'
    
    # Candidate new federal royalty rate: the onshore/offshore override if one
    # was supplied, otherwise the baseline rate for this class.
    if (shore.status=='Offshore' && !is.null(roy.fed.new.off)) {
      roy.fed.new.temp = roy.fed.new.off
    } else if (shore.status=='Onshore' && !is.null(roy.fed.new.on)) {
      roy.fed.new.temp = roy.fed.new.on
    } else {
      roy.fed.new.temp = roy.base[[i]]
    }
    
    # Adjust WTI prices for carbon adder and change in royalty rate
    # New royalty rate (only covered federal wells get the new rate)
    roy.new = if (policy.applies) roy.fed.new.temp else roy.base[[i]]
    
    # Carbon adder for oil, in WTI ($/bbl) equivalent.
    # These are $/tCO2 * tCO2/barrel = $/barrel equivalent
    carb.add.wti = if (policy.applies) {
      carb.add.fed.oil * oil.emis.rate.fed
    } else {
      carb.add.nonfed.oil * oil.emis.rate.nonfed
    }
    # Carbon adder for gas, in Henry Hub ($/mmbtu) equivalent.
    # These are $/tCO2 * 0.066 tCO2/mcf / (1.036 mmbtu/mcf) = $/mmbtu equivalent
    carb.add.hh = (if (policy.applies) {
      carb.add.fed.gas * gas.emis.rate.fed
    } else {
      carb.add.nonfed.gas * gas.emis.rate.nonfed
    })/1.036 # 1.036 mmbtu/mcf 
    
    # With onshore.only, offshore classes get no carbon adder at all
    if (shore.status=='Offshore' && onshore.only==TRUE) {
      carb.add.wti = 0
      carb.add.hh = 0
    }
    
    ## Before the policy change, use the baseline royalty rate
    ## After policy change, use the new royalty rate and carbon adder
    price.temp[date<shock.date, roy:= roy.base[[i]]]
    price.temp[date>=shock.date, roy:= roy.new]
    
    # Adjust oil price
    price.temp[date<shock.date, carb.add.per.bbl := 0]
    price.temp[date>=shock.date, carb.add.per.bbl := carb.add.wti*(1+carb.add.growth.rate/12)^(elapsed_months(date, shock.date))]
    price.temp[, net.crude.price := wti*(1-roy) - carb.add.per.bbl]
    price.temp[, net.crude.price := pmax(net.crude.price, 0 + price.floor)] # if carbon adder>driller's share of price, net price could go negative
    
    # Adjust gas price
    price.temp[date<shock.date, carb.add.per.mmbtu := 0]
    price.temp[date>=shock.date, carb.add.per.mmbtu := carb.add.hh*(1+carb.add.growth.rate/12)^(elapsed_months(date, shock.date))]
    price.temp[, net.gas.price := hh*(1-roy) - carb.add.per.mmbtu]
    price.temp[, net.gas.price := pmax(net.gas.price, 0 + price.floor/6)] # if carbon adder>driller's share of price, net price could go negative
    
    # Monthly log changes in net prices
    price.temp[, d.ln.wti := c(NA, diff(log(net.crude.price)))]
    price.temp[, d.ln.hh := c(NA, diff(log(net.gas.price)))]
    
    # Save projected price changes
    output[[i]] = copy(price.temp)
  }
  
  # Predicted change in spuds for every class. lm.regs.sim has half as many
  # cells as output, so mapply recycles it across the Covered/Uncovered copies.
  delta.spud = data.table(mapply(FUN=function(fit, data) predict(fit, newdata=data), lm.regs.sim, output))
  delta.spud = cbind(price.path.sim[date>=sim.start-months(1+lag.length), date],
                     delta.spud)
  
  setnames(delta.spud, c('date',well.class.string))
  
  if (use.spudreg.trend==FALSE) {
    # subtract the intercept to avoid extrapolating the trend term out of sample.
    # in practice, the thing to subtract out is the mean(intercept+MOY) fe, so that
    # if all price terms are zero, then delta_spud=0 on average over the course of the year.
    # E.g., the nonfed onshore gas drilling regression has a trend of -0.28% per month.
    # If we project this over 30 years, that would cause a decline of exp(-0.0028*12*30)-1 = 64%
    # in long-run gas supply, simply due to the trend. Not realistic.
    reg.coefs = sapply(lm.regs.sim, function(x) x$coefficients)
    reg.coefs = cbind(reg.coefs, reg.coefs)
    colnames(reg.coefs) = well.class.string
    reg.coefs.int = reg.coefs[grep('Intercept', rownames(reg.coefs)),]
    reg.coefs.moy = reg.coefs[grep('moy', rownames(reg.coefs)),]
    # stack on implicit "zero" FE for January (otherwise it's not the right average)
    reg.coefs.moy = rbind(0, reg.coefs.moy)
    reg.coefs.adj = reg.coefs.int + apply(reg.coefs.moy, 2, mean)
  
    # Remove coefficient adjustment
    delta.spud.mat = as.matrix(delta.spud[,-'date'])
    delta.spud.mat = delta.spud.mat - rep.row(reg.coefs.adj, nrow(delta.spud.mat))
    # Convert back to a data.table
    delta.spud = cbind(delta.spud[,'date'], delta.spud.mat)
  }
  # Cumulative change in spuds relative to baseline
  cum.delta.spud = delta.spud[,-'date'][, lapply(.SD, function(x) cumsum(na.omit(x)))]
  
  date.grid = delta.spud[complete.cases(delta.spud), unique(date)]
  
  base.spuds.grid = rep.row(as.numeric(base.spuds), nrow(cum.delta.spud))
  colnames(base.spuds.grid) = colnames(cum.delta.spud)
  
  # Share of federal spuds that are covered by the policy, clamped to [0, 1]:
  # 0 at shock.date, rising linearly to 1 at shock.date + lease.length years.
  share.spuds.covered = pmax(0, pmin(1, elapsed_months(date.grid, shock.date+years(lease.length))/(lease.length*12)+1))
  
  covered.mat = matrix(NA, nrow=nrow(base.spuds.grid), ncol=ncol(base.spuds.grid))
  colnames(covered.mat) = colnames(cum.delta.spud)
  # The patterns are case sensitive: 'Federal' does not match 'Nonfederal'
  # and '_Covered' does not match '_Uncovered'.
  covered.mat[,grep('Nonfederal.*_Covered', colnames(covered.mat))] = 0 # no nonfederal are covered
  covered.mat[,grep('Nonfederal.*_Uncovered', colnames(covered.mat))] = 1 # all nonfederal are uncovered
  covered.mat[,grep('Federal.*_Covered', colnames(covered.mat))] = share.spuds.covered # phase in coverage
  covered.mat[,grep('Federal.*_Uncovered', colnames(covered.mat))] = 1-share.spuds.covered # phase out uncovered
  
  base.spuds.grid = base.spuds.grid*covered.mat
  
  # A zero baseline gives log(0) = -Inf, which exp() maps back to zero spuds
  spuds = exp(log(base.spuds.grid) + cum.delta.spud)
  
  spuds = cbind(date=date.grid, spuds)
  
  ## Simulated Spud-to-Production Distributions ----------------------------------------------------
  ## Compute monthly number of wells beginning production using spud-to-prod distributions ----
  
  # For the first 24 months, some of the spud-to-prod distribution derives from
  # spuds that occurred before the price shock.
  max.dist.lags = max(which(stp.den>0, arr.ind=TRUE)[,1])
  
  # Stack spudcounts.class from before sim.start on top of spuds
  if (spuds[, min(date)]!=sim.start) stop('Inconsistent start dates')
  if (!(spuds[, min(date)-months(1)] %in% spudcounts.class[, date])) stop("Don't have lagged modeled date")
  
  # Historical spuds: duplicated for both coverage states, then the covered
  # copies are zeroed out.
  spudcounts.class.temp = cbind(spudcounts.class[,1], spudcounts.class[,-1], spudcounts.class[,-1])
  names(spudcounts.class.temp)[-1] = well.class.string
  spudcounts.class.temp[, (grep('Covered', names(spudcounts.class.temp), value=TRUE)) := 0]
  
  spuds = rbind(spudcounts.class.temp[between(date, sim.start - months(max.dist.lags),
                                              spuds[, min(date) - 1])],
                spuds)
  
  # If modeling a full ban, set covered federal spuds to zero.
  # NOTE(review): ban.lag is not applied here; the covered columns are zeroed
  # for every date.
  if (fed.leasing.ban) {
    ban.pattern = if (onshore.only==TRUE) 'Federal_Onshore_Covered' else 'Federal.*Covered'
    fed.cols = grep(ban.pattern, names(spuds))
    
    if (length(fed.cols) > 0) spuds[,c(fed.cols) := 0]
    
  }
  # Organize the distribution matrices
  stp.mat = matrix(stp.den, ncol=nrow(well.class.grid))
  colnames(stp.mat) = well.class.string
  
  # Reverse the stp distribution matrix so the most recent is on the bottom, as "shock.spuds" is sorted
  # from oldest (top) to newest (bottom).
  stp.mat.reverse = stp.mat[nrow(stp.mat):1, ] # alternatively, apply(shock.dists, 2, rev)
  
  prods = matrix(NA, nrow=nrow(spuds), ncol=nrow(well.class.grid))
  
  tt = spuds[date>=sim.start, .N]-1
  sim.start.idx = spuds[,which(date==sim.start)]
  for (t in 0:tt) {
    recent.spuds = spuds[between(date, sim.start+months(t-max.dist.lags+1), sim.start+months(t))]
    num.months = nrow(recent.spuds)
    
    # this is the previous, say, 31 months. Overlay with (reversed) spud-to-comp time for 
    # the same number of months. Last row is for most recent month, second last for previous, etc.
    stp.mat.reverse.temp = stp.mat.reverse[(nrow(stp.mat.reverse)-num.months+1):nrow(stp.mat.reverse),]
    
    prods[sim.start.idx + t, ] = apply( stp.mat.reverse.temp*recent.spuds[,-'date'] , 2, sum)
  }
  # Zero out the NA values. These are dates before the sim.start.
  # There is production in these, but they'll be captured by the 
  # production from existing wells.
  if (!all(is.na(prods[spuds[,date<sim.start],]))) stop('Overwriting new prods')
  prods[spuds[,date<sim.start],] = 0
  
  ## Compute oil production over time ----
  # Now that we have the new productions over time, compute cumulative production over time
  # Baseline oil production is steady-state production, equal to the number of wells
  # coming online each period times the sum of the production profile.
  # Cut production profiles off after 10 years
  if (!all(colnames(profiles.pct.liq.class)==colnames(ip.liq.class))) stop('liq profile class mismatch')
  if (!all(colnames(profiles.pct.gas.class)==colnames(ip.gas.class))) stop('gas profile class mismatch')
  profiles.liq.class.scaled = profiles.pct.liq.class*rep.row(ip.liq.class, nrow(profiles.pct.liq.class))
  profiles.gas.class.scaled = profiles.pct.gas.class*rep.row(ip.gas.class, nrow(profiles.pct.gas.class))
  
  # Compute oil production
  oil.prod <- gas.prod <-  matrix(0, nrow=nrow(spuds), ncol=nrow(well.class.grid))
  
  if (nrow(profiles.liq.class.scaled)!=nrow(profiles.gas.class.scaled)) stop("Profiles have different lengths")
  
  for (t in 0:tt) {
    # Calculate oil production from wells coming online at time t
    
    # Production / day = (Production / well / day)*(Wells)
    incr.oil.prod = profiles.liq.class.scaled*rep.row(prods[sim.start.idx + t,], nrow(profiles.liq.class.scaled))
    incr.gas.prod = profiles.gas.class.scaled*rep.row(prods[sim.start.idx + t,], nrow(profiles.gas.class.scaled))
    
    # If stream of production is longer than the remaining simulation length, cut the extra
    # Profile rows start from the beginning of each well's life (1) and goes through the 
    # end of its life (nrow(incr.oil.prod)=nrow(profiles)) or the amount of time left
    # in the sim (nrow(oil.prod) minus today)
    prof.rows = 1:min(nrow(incr.oil.prod), (nrow(oil.prod) - (sim.start.idx+t) + 1)) # if we're on the last date, 1 row to fill
    # Production rows to increment start from "today" and run for the length of production, or 
    # the end of the sim, whichever comes first (hence the "min")
    prod.rows = (sim.start.idx+t):min((sim.start.idx+t+nrow(incr.oil.prod)-1) , nrow(oil.prod))
    
    oil.prod[prod.rows,] = oil.prod[prod.rows,] + incr.oil.prod[prof.rows,]
    gas.prod[prod.rows,] = gas.prod[prod.rows,] + incr.gas.prod[prof.rows,]
    
  }
  # Collapse the Covered/Uncovered copies back to the original well classes.
  # Coverage is the last (slowest-varying) dimension, so the first half of the
  # columns is 'Covered' and the second half is 'Uncovered'.
  distinct.well.types = length(well.class.string)/2 # not including covered/uncovered
  distinct.well.strings = apply(as.matrix(well.class.grid[1:distinct.well.types,-4]), 1, paste, collapse='_')
  
  spuds.agg = spuds[,-1][, 1:distinct.well.types, with=FALSE] + spuds[,-1][, (distinct.well.types + 1:distinct.well.types), with=FALSE]
  spuds.agg = cbind(date=spuds[, date], data.table(spuds.agg))
  setnames(spuds.agg, c('date', distinct.well.strings))
  
  prods.agg = prods[, 1:distinct.well.types] + prods[, distinct.well.types + 1:distinct.well.types]
  prods.agg = cbind(date=spuds[, date], data.table(prods.agg))
  setnames(prods.agg, c('date', distinct.well.strings))
  
  # Total
  # Oil (scaled by 1e6)
  oil.prod.agg = (oil.prod[, 1:distinct.well.types] + oil.prod[, distinct.well.types + 1:distinct.well.types])/1e6
  oil.prod.agg = cbind(date=spuds[, date], data.table(oil.prod.agg))
  setnames(oil.prod.agg, c('date', distinct.well.strings))
  
  tot.oil.prod = apply(oil.prod/1e6, 1, sum)
  oil.prod.agg[, liq_Total := tot.oil.prod]
  
  # Gas (scaled by 1e6)
  gas.prod.agg = (gas.prod[, 1:distinct.well.types] + gas.prod[, distinct.well.types + 1:distinct.well.types])/1e6
  gas.prod.agg = cbind(date=spuds[, date], data.table(gas.prod.agg))
  setnames(gas.prod.agg, c('date', distinct.well.strings))
  
  tot.gas.prod = apply(gas.prod/1e6, 1, sum)
  gas.prod.agg[, gas_Total := tot.gas.prod]
  
  # Covered
  # Oil
  covered.cols = grep('Covered', well.class.string)
  oil.prod.covered = (oil.prod[, covered.cols])/1e6
  oil.prod.covered = cbind(date=spuds[, date], data.table(oil.prod.covered))
  setnames(oil.prod.covered, c('date', distinct.well.strings))
  
  oil.prod.covered[, liq_Total := apply(oil.prod.covered[,-1], 1, sum)]
  
  # Gas
  gas.prod.covered = (gas.prod[, covered.cols])/1e6
  gas.prod.covered = cbind(date=spuds[, date], data.table(gas.prod.covered))
  setnames(gas.prod.covered, c('date', distinct.well.strings))
  
  gas.prod.covered[, gas_Total := apply(gas.prod.covered[,-1], 1, sum)]
  
  # Make sure new and existing production frames are the same size, for adding together
  min.date = oil.prod.agg[, min(date)]
  ex.well.liq = ex.well.liq[date>=min.date]
  ex.well.gas = ex.well.gas[date>=min.date]
  
  # Everything above was from new wells. Add in production from existing wells.
  if (any(names(oil.prod.agg)!=names(ex.well.liq))) stop("Existing liq prod and new well classes don't match")
  if (any(names(gas.prod.agg)!=names(ex.well.gas))) stop("Existing gas prod and new well classes don't match")
  if ( any( ex.well.liq[order(date), date] !=
            oil.prod.agg[date%in% ex.well.liq[, date]][order(date), date] ) ) stop('Existing and new well date mismatch (liq)')
  if ( any( ex.well.gas[order(date), date] !=
            gas.prod.agg[date%in% ex.well.gas[, date]][order(date), date] ) ) stop('Existing and new well date mismatch (gas)')
  
  # Add production from existing and new wells
  oil.prod.total = ex.well.liq[order(date), -'date'] + oil.prod.agg[date%in% ex.well.liq[, date]][order(date), -'date']
  oil.prod.total = cbind(date=oil.prod.agg[date%in% ex.well.liq[, date]][order(date), date], oil.prod.total)
  gas.prod.total = ex.well.gas[order(date), -'date'] + gas.prod.agg[date%in% ex.well.gas[, date]][order(date), -'date']
  gas.prod.total = cbind(date=gas.prod.agg[date%in% ex.well.gas[, date]][order(date), date], gas.prod.total)
  
  res = list(OilProduction_AllWells = oil.prod.total,
             GasProduction_AllWells = gas.prod.total,
             OilProduction_ExWells = ex.well.liq,
             GasProduction_ExWells = ex.well.gas,
             OilProduction_NewWells = oil.prod.agg,
             GasProduction_NewWells = gas.prod.agg,
             spuds = spuds.agg,
             prods = prods.agg,
             PricePaths = output,
             IPs = data.table(product=c('liq','gas'), rbind(ip.liq.class, ip.gas.class)),
             OilProduction_NewWellsCovered = oil.prod.covered,
             GasProduction_NewWellsCovered = gas.prod.covered)
  # browser() # for testing
  return(res)
}


# Global sim
cons.elast.curve = function(p, p0, q0, e) {
  # Constant-elasticity demand (or supply) curve.
  #
  # p...price at which to evaluate the curve
  # p0, q0...reference price and quantity; the curve passes through (q0, p0)
  # e...elasticity (constant along the curve)
  #
  # Returns the quantity at price p. All arguments are combined with ordinary
  # vectorised arithmetic, so vectors are handled element by element.
  relative.price = p / p0
  q0 * relative.price^e
}

excess.demand = function(p.adj.oil=1, p.adj.gas=1, r=0.02, weo.proj, price.path.dummy,
                         sim.start, sim.end=NULL,
                         base.spuds.adj = base.spuds.adj,
                         price.start = shock.date,
                         US.adjustment.factors = c(liq=1, gas=1), # calibrated once, so baseline sim 2019 production == actuals
                         Net.Demand.Adjustment = NULL, Demand.Share = 0.5, # split adjustment 50/50 between demand and (row) supply
                         starting.inventories = c(Oil=6700*0.8, Gas=4000+3500), # oil in million barrels, gas in bcf
                         # Sources for inventory levels. These are estimates and likely lower bounds.
                         # https://www.csis.org/analysis/oil-inventory-challenge
                         # http://ir.eia.gov/ngs/ngs.html
                         # https://www.eia.gov/todayinenergy/detail.php?id=43235
                         # brent.premium = brent.wti.fut[, mean(spread.pct)],
                         base.prices = NULL,
                         spudregs.use = lm.regs, 
                         oil.dem.elast = -0.2, 
                         gas.dem.elast = -0.2, 
                         roy.fed.new.on.sim = NULL,
                         roy.fed.new.off.sim = NULL,
                         onshore.only=FALSE, # only applies to carbon adders and fed leasing. for RRs, specify directly
                         carb.add.fed.sim.oil = 0,
                         carb.add.fed.sim.gas = 0,
                         carb.add.nonfed.sim.oil = 0,
                         carb.add.nonfed.sim.gas = 0,
                         fed.leasing.ban = FALSE, ban.lag=0, # lag in ban taking effect, in years
                         carb.add.growth.rate = 0.02,
                         us.share.gas.marketed=0.884,
                         global.gas=TRUE) {
  # Computes oil and gas excess demand (demand minus supply) along a price
  # path obtained by scaling the baseline prices from 'price.start' onwards.
  #
  # p.adj.oil & p.adj.gas are multipliers applied to the baseline oil (wti,
  #   brent) and gas (hh) prices on dates >= price.start
  # r is not used in this function
  # weo.proj is a data.table of projections by 'date'; the columns read below
  #   (e.g. world_oil_cons, brent.base, gas.spread, row_crude_elasticity) must
  #   be present after merging in the prices
  # price.path.dummy is a data.table with 'date' and the price columns
  #   'wti', 'brent', 'hh', which are replaced by base.prices
  # base.prices must be supplied; it is column-bound onto price.path.dummy and
  #   has to provide the 'wti', 'brent' and 'hh' columns
  # Net.Demand.Adjustment is an optional data.table with 'date',
  #   'Oil.Net.Demand.Adjustment' and 'Gas.Net.Demand.Adjustment'
  #   (all zeros when NULL)
  # global.gas: if TRUE gas clears globally, otherwise in the US only
  #
  # Returns a list with 'global.sim' (the simulated data.table), 'us.sim' (the
  # result of my.sim) and 'Excess.Demand' (mean excess oil and gas demand over
  # dates >= price.start).
  #
  # NOTE(review): the following are not arguments and are resolved from the
  # enclosing environment: shock.date, sim.start.date, roy.base,
  # prod.tot.liq.wide, prod.tot.gas.wide, ip.liq.class, ip.gas.class,
  # stp.densities.2018, profiles.pct.liq.class, profiles.pct.gas.class,
  # spudcounts.class, and (via defaults) lm.regs.
  
  # The default 'base.spuds.adj = base.spuds.adj' refers to itself, so forcing
  # it raises a recursive-promise error. When the argument is omitted, look the
  # object up in the environment where this function was defined instead.
  if (missing(base.spuds.adj)) {
    base.spuds.adj = get("base.spuds.adj", envir = parent.env(environment()))
  }
  
  # Without base.prices the price columns would be dropped and never replaced,
  # which fails further down with an unhelpful "object not found" error.
  if (is.null(base.prices)) stop("base.prices must be supplied: it provides the wti, brent and hh columns")
  
  # Overwrite the WTI & HH values after start for the price path.
  # Keep historical prices from before today.
  # First, make a copy of price.path to make sure we don't modify by reference.
  price.path.dummy = copy(price.path.dummy)
  price.path.dummy[, c('wti','brent','hh') := NULL]
  price.path.dummy = copy(cbind(price.path.dummy, base.prices))
  
  # Adjust baseline prices according to p.adj.oil and p.adj.gas
  price.path.dummy[date>=price.start, wti := wti*p.adj.oil]
  price.path.dummy[date>=price.start, brent := brent*p.adj.oil]
  # assumes brent changes from baseline the same % as WTI. 
  # Effectively assumes that their spread is a constant %. 
  # This could be relaxed but it would create another degree of freedom 
  # (would need to solve for WTI, HH, and Brent prices)
  # price.path.dummy[date>=price.start, brent := wti*(1+brent.premium)]
  price.path.dummy[date>=price.start, hh := hh*p.adj.gas]
  
  global.sim = merge(weo.proj, price.path.dummy[,.(date, wti, brent, hh)], by='date', all.x=TRUE, all.y=FALSE)
  
  # Drop prior years so we're not overwriting data.
  global.sim = global.sim[date>=sim.start]
  
  # Calculate global prices.
  # assumes ROW suppliers get HH benchmark, and the spread is delivery cost. a simplification
  if (global.gas) {
    global.sim[, row.gas.delivered.price.sim := hh + gas.spread]
  } else {
    global.sim[, row.gas.delivered.price.sim := gas_price_intl.base]
  }
  
  # Adjust demand based on the Net.Demand.Adjustment. 
  # This reflects changes in the forecast since COVID, and
  # ensures that S=D when prices are at the futures values.
  if (is.null(Net.Demand.Adjustment)) {
    Net.Demand.Adjustment = data.table(date=global.sim[, date], 
                                       Oil.Net.Demand.Adjustment = 0,
                                       Gas.Net.Demand.Adjustment = 0)
  }
  global.sim = merge(global.sim, Net.Demand.Adjustment, by='date')
  ### Demand
  # Split excess demand adjustment between demand and supply.
  # Demand.Share=0.5 assumes that half of the overshoot in demand projections,
  # relative to actual, is due to negative demand shocks (Covid), and the remaining half
  # is due to supply shocks (Saudi/Russia).
  # Here we SUBTRACT demand
  global.sim[, world_oil_cons := world_oil_cons - Oil.Net.Demand.Adjustment*Demand.Share]
  if (global.gas) {
    global.sim[, world_gas_cons_bcfd := world_gas_cons_bcfd - Gas.Net.Demand.Adjustment*Demand.Share]
  } else {
    global.sim[, us_gas_cons_bcfd := us_gas_cons_bcfd - Gas.Net.Demand.Adjustment*Demand.Share]
  }
  
  # Calculate ROW supply and demand at these prices
  #   Oil Demand
  global.sim[, row.oil.demand :=  cons.elast.curve(p = brent, 
                                                   p0 = brent.base,
                                                   q0 = (world_oil_cons - us_oil_cons),
                                                   e=oil.dem.elast)]
  
  global.sim[, US.oil.demand :=  cons.elast.curve(p=brent, 
                                                  p0=brent.base,
                                                  q0=us_oil_cons,
                                                  e=oil.dem.elast)]
  
  global.sim[, Tot.Oil.Demand :=  row.oil.demand + US.oil.demand]
  # Note: for oil, assume that US demand keys off Brent, not WTI.
  # As a result, the total demand would be the same if we had calculated
  # a single global demand curve (since the prices and elasticities are
  # assumed to be the same). Further, if US demand keyed off WTI,
  # it wouldn't matter because of the assumption of
  # a constant WTI/Brent percentage spread.  A 10% increase in WTI would
  # lead to a 10% increase in Brent. So changing this shouldn't matter at all.
  #
  # However, doing this separately allows us 
  # to track US consumption and hence track US imports/exports.
  
  #   Gas Demand
  global.sim[, row.gas.demand :=  cons.elast.curve(p = row.gas.delivered.price.sim, 
                                                   p0 = gas_price_intl.base,
                                                   q0 = (world_gas_cons_bcfd - us_gas_cons_bcfd),
                                                   e=gas.dem.elast)]
  
  global.sim[, US.gas.demand :=  cons.elast.curve(p=hh, 
                                                  p0=hh.base,
                                                  q0=us_gas_cons_bcfd,
                                                  e=gas.dem.elast)]
  
  global.sim[, Tot.Gas.Demand :=  row.gas.demand + US.gas.demand]
  
  
  ### Supply
  # ROW Supply
  #   Oil Supply
  global.sim[, row.oil.supply := cons.elast.curve(p=brent, 
                                                  p0=brent.base,
                                                  q0=crude_prod_non_us_mbpd,
                                                  e=row_crude_elasticity)]
  #   Gas Supply. Based on international price (not HH)
  global.sim[, row.gas.supply := cons.elast.curve(p=row.gas.delivered.price.sim, 
                                                  p0=gas_price_intl.base, 
                                                  q0=gas_prod_non_us_bcfd,
                                                  e=row_gas_elasticity)]
  
  # Adjustment for supply side of the projection overshoot. Here we ADD to supply
  global.sim[, row.oil.supply := row.oil.supply + Oil.Net.Demand.Adjustment*(1-Demand.Share)]
  if (global.gas) global.sim[, row.gas.supply :=  row.gas.supply + Gas.Net.Demand.Adjustment*(1-Demand.Share)]
  # with an isolated US gas market need to add to US supply (not global). Do this after pulling in US sim below
  
  # US Supply
  # Adjust oil and gas production (existing and new IPs) based on US.adjustment.factors.
  # Use this to calibrate the model to account for "missing wells" and non-marketed gas production
  
  # Adjust total production from existing wells
  # liquids:
  prod.tot.liq.wide.adj = data.table(date=prod.tot.liq.wide[,date],
                                     prod.tot.liq.wide[,-'date']*US.adjustment.factors['liq'])
  # gas:
  prod.tot.gas.wide.adj = data.table(date=prod.tot.gas.wide[,date],
                                     prod.tot.gas.wide[,-'date']*US.adjustment.factors['gas'])
  # Adjust IPs for new wells
  # liquids:
  ip.liq.class.adj = ip.liq.class*US.adjustment.factors['liq']
  # gas:
  ip.gas.class.adj = ip.gas.class*US.adjustment.factors['gas']
  
  ## Cut off dates after sim.end
  if (is.null(sim.end)) sim.end = price.path.dummy[, max(date)]
  
  if (sim.end != price.path.dummy[, max(date)]) {
    global.sim = copy(global.sim[date<=sim.end,])
    price.path.dummy = copy(price.path.dummy[date<=sim.end,])
    prod.tot.liq.wide.adj = copy(prod.tot.liq.wide.adj[date<=sim.end,])
    prod.tot.gas.wide.adj = copy(prod.tot.gas.wide.adj[date<=sim.end,])
  }
  
  # NOTE(review): the US simulation is started at the enclosing-scope
  # 'sim.start.date', not at this function's 'sim.start' argument, and the
  # policy date is the enclosing-scope 'shock.date', not 'price.start'.
  # Confirm that this is intended before changing either.
  us.sim = my.sim(shock.date = shock.date,  # shock.date is date of policy change (royalty or carbon adder)
                  price.path.sim = price.path.dummy, 
                  lm.regs.sim = spudregs.use, 
                  ex.well.liq = prod.tot.liq.wide.adj, 
                  ex.well.gas = prod.tot.gas.wide.adj, 
                  stp.den = stp.densities.2018,
                  base.spuds = base.spuds.adj, roy.base = roy.base, sim.start = sim.start.date, 
                  profiles.pct.liq.class = profiles.pct.liq.class, profiles.pct.gas.class = profiles.pct.gas.class,
                  ip.liq.class = ip.liq.class.adj, ip.gas.class = ip.gas.class.adj, 
                  spudcounts.class = spudcounts.class,
                  use.spudreg.trend = FALSE,
                  price.floor=0.1,
                  roy.fed.new.on = roy.fed.new.on.sim, 
                  roy.fed.new.off = roy.fed.new.off.sim, 
                  onshore.only=onshore.only, 
                  fed.leasing.ban = fed.leasing.ban, ban.lag=ban.lag, # lag in ban taking effect, in years
                  carb.add.fed.oil = carb.add.fed.sim.oil,
                  carb.add.fed.gas = carb.add.fed.sim.gas,
                  carb.add.nonfed.oil = carb.add.nonfed.sim.oil,
                  carb.add.nonfed.gas = carb.add.nonfed.sim.gas,
                  carb.add.growth.rate = carb.add.growth.rate)
  
  global.sim = merge(global.sim, us.sim$OilProduction_AllWells[,.(date, US.oil.supply = liq_Total)], 
                     by='date', all.x=TRUE, all.y=FALSE)
  # Note: result of us.sim is GROSS gas production, whereas all other values from IEA
  # are MARKETED production. Here we adjust for the % marketed (88.4%).
  # The results inside us.sim will remain gross numbers however.
  global.sim = merge(global.sim, us.sim$GasProduction_AllWells[,.(date, US.gas.supply = us.share.gas.marketed*gas_Total)], 
                     by='date', all.x=TRUE, all.y=FALSE)
  
  # Isolated US gas market: the supply-side share of the adjustment goes to US supply
  if (!global.gas) global.sim[, US.gas.supply :=  US.gas.supply + Gas.Net.Demand.Adjustment*(1-Demand.Share)]
  
  global.sim[, Tot.Oil.Supply := US.oil.supply + row.oil.supply]
  global.sim[, Tot.Gas.Supply := US.gas.supply + row.gas.supply]
  
  global.sim[, Excess.Oil.Demand := Tot.Oil.Demand - Tot.Oil.Supply]
  if (global.gas) {
    global.sim[, Excess.Gas.Demand := Tot.Gas.Demand - Tot.Gas.Supply]
  } else {
    global.sim[, Excess.Gas.Demand := US.gas.demand - US.gas.supply]
  }
  
  # Inventories fall by cumulative excess demand; 365.25/12 converts a daily
  # rate into a monthly volume.
  global.sim[, Oil.Inventories := starting.inventories['Oil'] - cumsum(Excess.Oil.Demand*365.25/12)]
  global.sim[, Gas.Inventories := starting.inventories['Gas'] - cumsum(Excess.Gas.Demand*365.25/12)]
  # browser() 
  mean.excess.demand = c(Excess.Oil.Demand = global.sim[date>=price.start, mean(Excess.Oil.Demand, na.rm=TRUE)],
                         Excess.Gas.Demand = global.sim[date>=price.start, mean(Excess.Gas.Demand, na.rm=TRUE)])
  return(list(global.sim = global.sim,
              us.sim = us.sim,
              Excess.Demand = mean.excess.demand))
  
}